{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from copy import deepcopy\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import GPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = '.'\n",
    "ckpt_reader = tf.train.load_checkpoint(\n",
    "        os.path.join(model_path, 'model.ckpt-220000'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = list(ckpt_reader.get_variable_to_dtype_map().keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch_size, seq_len, context_size\n",
    "embedding_size = 1536\n",
    "B, L, C = [4, 1024, embedding_size]\n",
    "vocab_size = 8021\n",
    "max_position = 1024\n",
    "attention_head = 24\n",
    "attention_dropout = 0.1\n",
    "residual_dropout = 0.1\n",
    "layer_size = 48\n",
    "embedding_dropout = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def GPTModel(**kwargs):\n",
    "    input_ids = tf.keras.layers.Input(\n",
    "        shape=(None, ),\n",
    "        name='input_ids',\n",
    "        dtype=tf.int64\n",
    "    )\n",
    "    out = GPT(**kwargs)(input_ids)\n",
    "    model = tf.keras.Model(\n",
    "        inputs=input_ids,\n",
    "        outputs=out)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gpt = GPTModel(\n",
    "#     vocab_size=vocab_size,\n",
    "#     layer_size=layer_size,\n",
    "#     block_size=max_position,\n",
    "#     embedding_dropout=embedding_dropout,\n",
    "#     embedding_size=C,\n",
    "#     num_attention_heads=attention_head,\n",
    "#     attention_dropout=attention_dropout,\n",
    "#     residual_dropout=residual_dropout\n",
    "# )\n",
    "gpt = GPT(\n",
    "    vocab_size=vocab_size,\n",
    "    layer_size=layer_size,\n",
    "    block_size=max_position,\n",
    "    embedding_dropout=embedding_dropout,\n",
    "    embedding_size=C,\n",
    "    num_attention_heads=attention_head,\n",
    "    attention_dropout=attention_dropout,\n",
    "    residual_dropout=residual_dropout\n",
    ")\n",
    "gpt._set_inputs(\n",
    "    tf.keras.backend.placeholder((None, None), dtype=tf.int64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt.build(tf.TensorShape([None, None]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "772\n"
     ]
    }
   ],
   "source": [
    "print(len(gpt.weights))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['embedding/embeddings:0',\n",
       " 'position_embeddings:0',\n",
       " 'LayerNorm_embed_norm/gamma:0',\n",
       " 'LayerNorm_embed_norm/beta:0',\n",
       " 'layer00/attention/query_layer/kernel:0',\n",
       " 'layer00/attention/query_layer/bias:0',\n",
       " 'layer00/attention/key_layer/kernel:0',\n",
       " 'layer00/attention/key_layer/bias:0',\n",
       " 'layer00/attention/value_layer/kernel:0',\n",
       " 'layer00/attention/value_layer/bias:0']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[x.name for x in gpt.weights[:10]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "name_shape = {}\n",
    "names = []\n",
    "for w in gpt.weights:\n",
    "    n = w.name[:-2]\n",
    "    name_shape[n] = w.shape\n",
    "    names.append(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "772\n"
     ]
    }
   ],
   "source": [
    "print(len(names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_layer = sorted([\n",
    "    x for x in weights\n",
    "    if 'adafactor_' not in x\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "773\n"
     ]
    }
   ],
   "source": [
    "print(len(all_layer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "global_step () False\n"
     ]
    }
   ],
   "source": [
    "copy_name_shape = deepcopy(name_shape)\n",
    "output = {}\n",
    "for x in all_layer:\n",
    "    v = ckpt_reader.get_tensor(x)\n",
    "    # x = x.replace('newslm/', 'gpt/')\n",
    "    x = x.replace('newslm/', '')\n",
    "    x = x.replace('context_projection_layer', 'attention/context_projection_layer')\n",
    "    x = x.replace('key_layer', 'attention/key_layer')\n",
    "    x = x.replace('query_layer', 'attention/query_layer')\n",
    "    x = x.replace('value_layer', 'attention/value_layer')\n",
    "    x = x.replace('embeddings/LayerNorm_embed_norm', 'LayerNorm_embed_norm')\n",
    "    x = x.replace('embeddings/pos_embed', 'position_embeddings')\n",
    "    x = x.replace('embeddings/word_embed', 'embedding/embeddings')\n",
    "    \n",
    "    exists = x in copy_name_shape and v.shape == copy_name_shape[x]\n",
    "    if exists:\n",
    "        output[x] = v\n",
    "        del copy_name_shape[x]\n",
    "    else:\n",
    "        print(x, v.shape, exists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt.set_weights([output[w.name[:-2]] for w in gpt.weights])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 输入 [687, 1646, 1646, 3134, 581, 5774]\n",
    "# 输出 [1646, 2741, 3134, 2157, 5774, 6294]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tokenizers import BertWordPieceTokenizer\n",
    "tokenizer = BertWordPieceTokenizer(\n",
    "    './clue-vocab.txt',\n",
    "    lowercase=True,\n",
    "    add_special_tokens=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_gather(a, b):\n",
    "    return tf.gather(a, b, batch_dims=1)\n",
    "\n",
    "\n",
    "def top_p_sample(logits, num_samples=1, p=0.95):\n",
    "    batch_size, vocab_size = logits.shape\n",
    "    probs = tf.nn.softmax(logits, axis=-1)\n",
    "    # [batch_size, vocab_perm]\n",
    "    indices = tf.argsort(probs, direction='DESCENDING')\n",
    "    cumulative_probabilities = tf.math.cumsum(batch_gather(probs, indices), axis=-1, exclusive=False)\n",
    "\n",
    "    # find the top pth index to cut off. careful we don't want to cutoff everything!\n",
    "    # result will be [batch_size, vocab_perm]\n",
    "    p_expanded = p if isinstance(p, float) else p[:, None]\n",
    "    exclude_mask = tf.logical_not(\n",
    "        tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))\n",
    "\n",
    "    # OPTION A - sample in the sorted space, then unsort.\n",
    "    logits_to_use = batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10\n",
    "    sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)\n",
    "    sample = batch_gather(indices, sample_perm)\n",
    "\n",
    "    return tf.cast(sample, tf.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def sample(tokenizer, gpt, sentence,\n",
    "#            number=5, length=15, layer_size=48,\n",
    "#            embedding_size=1536, attention_head=24):\n",
    "\n",
    "#     tokens = tokenizer.encode(sentence).ids\n",
    "#     i = tf.constant(0, dtype=tf.int64)\n",
    "#     initial_inputs = tf.constant([tokens] * number, dtype=tf.int64)\n",
    "#     initial_logits, kv_cache = gpt(initial_inputs, use_cache=True)\n",
    "#     inputs = top_p_sample(initial_logits[:, -1, :])\n",
    "#     stores = tf.concat([initial_inputs, inputs], axis=1)\n",
    "    \n",
    "#     def _cond(i, inputs, kv_cache, stores):\n",
    "#         return i < length\n",
    "\n",
    "#     def _body(i, inputs, kv_cache, stores):\n",
    "#         new_logits, new_kv_cache = gpt(inputs, kv_cache=kv_cache, use_cache=True)\n",
    "#         new_inputs = top_p_sample(new_logits[:, -1, :])\n",
    "#         new_stores = tf.concat([stores, new_inputs], axis=-1)\n",
    "#         new_kv_cache = tf.concat([\n",
    "#             kv_cache,\n",
    "#             new_kv_cache\n",
    "#         ], axis=-2)\n",
    "#         new_i = i + 1\n",
    "#         return [new_i, new_inputs, new_kv_cache, new_stores]\n",
    "\n",
    "#     result = tf.while_loop(\n",
    "#         _cond, _body,\n",
    "#         loop_vars=[i, inputs, kv_cache, stores],\n",
    "#         shape_invariants=[\n",
    "#             tf.TensorShape(None),\n",
    "#             tf.TensorShape([number, None]),\n",
    "#             tf.TensorShape([\n",
    "#                 layer_size, number, 2,\n",
    "#                 attention_head, None,\n",
    "#                 embedding_size // attention_head\n",
    "#             ]),\n",
    "#             tf.TensorShape([\n",
    "#                 number, None\n",
    "#             ])\n",
    "#         ]\n",
    "#     )\n",
    "#     return result[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ret = sample(tokenizer, gpt, '生活总是要继续', length=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for x in ret.numpy():\n",
    "#     print(tokenizer.decode(x))\n",
    "#     print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def serve(inputs):\n",
    "    return gpt(inputs, kv_cache=None, use_cache=True)\n",
    "\n",
    "\n",
    "@tf.function\n",
    "def serve_cache(inputs, kv_cache):\n",
    "    return gpt(inputs, kv_cache=kv_cache, use_cache=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "serve_concrete = serve.get_concrete_function(\n",
    "    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name=\"inp\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "serve_cache_concrete = serve_cache.get_concrete_function(\n",
    "    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name=\"inp\"),\n",
    "    tf.TensorSpec(shape=[\n",
    "        layer_size, None, 2, attention_head,\n",
    "        None, embedding_size // attention_head\n",
    "    ], dtype=tf.float32, name=\"kv_cache\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = serve_concrete(\n",
    "    tf.constant([[1]], tf.int64)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1, 8021) (48, 1, 2, 24, 1, 64)\n"
     ]
    }
   ],
   "source": [
    "print(r[0].shape, r[1].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "r2 = serve_cache_concrete(\n",
    "    tf.constant([[1]], tf.int64),\n",
    "    r[1]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1, 8021) (48, 1, 2, 24, 1, 64)\n"
     ]
    }
   ],
   "source": [
    "print(r2[0].shape, r2[1].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def sample(initial_inputs, length):\n",
    "    layer_size = 48\n",
    "    embedding_size = 1536\n",
    "    attention_head = 24\n",
    "\n",
    "    i = tf.constant(0, dtype=tf.int64)\n",
    "    initial_logits, kv_cache = serve(initial_inputs)\n",
    "    inputs = top_p_sample(initial_logits[:, -1, :])\n",
    "    stores = tf.concat([initial_inputs, inputs], axis=1)\n",
    "\n",
    "    def _cond(i, inputs, kv_cache, stores):\n",
    "        return i < length\n",
    "\n",
    "    def _body(i, inputs, kv_cache, stores):\n",
    "        new_logits, new_kv_cache = serve_cache(inputs, kv_cache)\n",
    "        \n",
    "        new_inputs = top_p_sample(new_logits[:, -1, :])\n",
    "        new_stores = tf.concat([stores, new_inputs], axis=-1)\n",
    "        new_kv_cache = tf.concat([\n",
    "            kv_cache,\n",
    "            new_kv_cache\n",
    "        ], axis=-2)\n",
    "        new_i = i + 1\n",
    "        return [new_i, new_inputs, new_kv_cache, new_stores]\n",
    "\n",
    "    result = tf.while_loop(\n",
    "        _cond, _body,\n",
    "        loop_vars=[i, inputs, kv_cache, stores],\n",
    "        shape_invariants=[\n",
    "            tf.TensorShape(None),\n",
    "            tf.TensorShape([None, None]),\n",
    "            tf.TensorShape([\n",
    "                layer_size, None, 2,\n",
    "                attention_head, None,\n",
    "                embedding_size // attention_head\n",
    "            ]),\n",
    "            tf.TensorShape([\n",
    "                None, None\n",
    "            ])\n",
    "        ]\n",
    "    )\n",
    "    return result[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = tokenizer.encode('天气不错').ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[1646, 3134,  581, 5774]])>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.constant([tokens], dtype=tf.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[1646 3134  581 5774 6294 2302 1153 3290 5494 3684 6294 2377  705 1181\n",
      "   579 2805 1160 6294 4384 2863]], shape=(1, 20), dtype=int64)\n",
      "天 气 不 错 ， 想 去 海 边 玩 ， 所 以 叫 上 朋 友 ， 结 果\n"
     ]
    }
   ],
   "source": [
    "ret = sample(\n",
    "    tf.constant([tokens], dtype=tf.int64),\n",
    "    tf.constant(15, dtype=tf.int64)\n",
    ")\n",
    "print(ret)\n",
    "for s in ret.numpy():\n",
    "    print(tokenizer.decode(s))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[1646 3134  581 5774 6294 1893 1840 5306  708 3261 3241]\n",
      " [1646 3134  581 5774 6294 2804  660  678 2413 5507 1047]], shape=(2, 11), dtype=int64)\n",
      "天 气 不 错 ， 小 宝 贝 们 活 泼\n",
      "天 气 不 错 ， 有 些 人 把 运 动\n"
     ]
    }
   ],
   "source": [
    "ret = sample(\n",
    "    tf.constant([tokens, tokens], dtype=tf.int64),\n",
    "    tf.constant(6, dtype=tf.int64)\n",
    ")\n",
    "print(ret)\n",
    "for s in ret.numpy():\n",
    "    print(tokenizer.decode(s))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/qhduan/.local/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "WARNING:tensorflow:From /home/qhduan/.local/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "INFO:tensorflow:Assets written to: ./gpt_model_tf2/assets\n"
     ]
    }
   ],
   "source": [
    "# gpt.save('./gpt_model_tf2', include_optimizer=False, signatures={\n",
    "#     'serving_default': serve_concrete,\n",
    "#     'serving_cache': serve_cache_concrete\n",
    "# })\n",
    "\n",
    "# v = gpt.signatures['serving_default'](\n",
    "#     inp=tf.constant([\n",
    "#         [1,],\n",
    "#         [1,],\n",
    "#         [1,],\n",
    "#     ], dtype=tf.int64)\n",
    "# )\n",
    "# print(v['output_0'].shape, v['output_1'].shape)\n",
    "# v2 = gpt.signatures['serving_cache'](\n",
    "#     inp=tf.constant([\n",
    "#         [1,],\n",
    "#         [1,],\n",
    "#         [1,],\n",
    "#     ], dtype=tf.int64),\n",
    "#     kv_cache=v['output_1']\n",
    "# )\n",
    "# print(v2['output_0'].shape, v2['output_1'].shape)\n",
    "\n",
    "gpt.save('./gpt_model_tf2', include_optimizer=False, signatures={\n",
    "    'serving_default': sample.get_concrete_function(\n",
    "        tf.TensorSpec(shape=[None, None], dtype=tf.int64, name=\"inp\"),\n",
    "        tf.TensorSpec(shape=[None,], dtype=tf.int64, name=\"length\")\n",
    "    )\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
